Skip to content

feat: add Quasar Attention and standalone model implementation#805

Open
troy12x wants to merge 4 commits intofla-org:mainfrom
troy12x:feat/quasar
Open

feat: add Quasar Attention and standalone model implementation#805
troy12x wants to merge 4 commits intofla-org:mainfrom
troy12x:feat/quasar

Conversation

@troy12x
Copy link
Copy Markdown

@troy12x troy12x commented Mar 31, 2026

Pull Request: Add Quasar Attention and Standalone Model Implementation

Summary

This PR introduces Quasar Attention, a highly optimized linear attention variant derived from Kimi Delta Attention (KDA) but featuring significant architectural optimizations and kernel refinements. Quasar achieves superior throughput and memory efficiency, particularly at long context lengths.

This PR includes:

  1. Quasar Attention Triton Kernels: Fused chunk-wise forward and backward kernels in fla/ops/quasar.
  2. QuasarAttention Layer: A standalone attention layer in fla/layers/quasar.py.
  3. Quasar Model: A complete HuggingFace-compatible model implementation in fla/models/quasar, including QuasarConfig, QuasarModel, and QuasarForCausalLM.
  4. Library Integration: Full registration of Quasar components in the fla library root interfaces.

Benchmarks

Quasar demonstrates superior hardware efficiency compared to baseline linear attention architectures.

High-Throughput Performance

Setup: 8x NVIDIA B200, 2B Model, 64k Context Length

Architecture Throughput (Tokens/sec)
Quasar 478,559
Kimi Delta Attention (KDA) 456,163
Gated Delta Attention 447,784

Scaling and Memory Efficiency

Setup: Single NVIDIA B200, 1B Model

Context Length Quasar Throughput KDA Throughput Speedup
16k 123,259 tok/s 105,052 tok/s +17.3%
32k 146,828 tok/s 110,225 tok/s +33.2%

References

Implementation Details

  • Branding: All components follow the quasar nomenclature to prevent symbol collisions with upstream KDA implementations.
  • Independence: The Quasar module is self-contained, including its own recomputed kernels and configuration classes.
  • Compatibility: Supports both standalone Quasar models and hybrid attention configurations within the FLA framework.

Summary by CodeRabbit

  • New Features

    • Introduces Quasar Attention (chunked and fused-recurrent modes) with GPU-optimized kernels, a QuasarAttention layer, and a Quasar model family including config, base model, and causal LM.
    • Adds a compatibility layer to detect and expose optional distributed/tensor APIs.
  • Documentation

    • Adds a Quasar Attention design doc with benchmarks, references, and implementation overview.

This PR introduces Quasar Attention, featuring significant kernel optimizations
for modern GPU architectures.

Key changes:
- Standalone Quasar Attention Triton kernels (fla/ops/quasar)
- QuasarAttention layer integration (fla/layers/quasar.py)
- Full Quasar model suite (fla/models/quasar)
- Benchmarks show Quasar outperforming KDA/GDA in throughput and memory efficiency.
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Mar 31, 2026

Walkthrough

Adds a full Quasar Attention feature: Triton forward/backward kernels (chunk, intra-chunk, fused recurrent, gate, WY recompute, forward-substitution), a QuasarAttention layer, HuggingFace-style Quasar model/config/for-causal-lm, a distributed-compat shim, and registers Quasar symbols in package exports.

Changes

Cohort / File(s) Summary
Documentation
PR_DESCRIPTION.md
New design doc describing Quasar Attention feature set, APIs, benchmarks, and references.
Compatibility & Exports
fla/distributed_compat.py, fla/layers/__init__.py, fla/models/__init__.py, fla/models/quasar/__init__.py
Added Torch distributed compatibility shim; registered Quasar exports in layers/models package all.
Layer API
fla/layers/quasar.py
New QuasarAttention nn.Module with chunk/fused_recurrent modes, optional short conv, RoPE handling, per-token beta, caching and padding/unpadding logic.
Model Config & Impl
fla/models/quasar/configuration_quasar.py, fla/models/quasar/modeling_quasar.py
Added QuasarConfig, QuasarModel, QuasarForCausalLM, QuasarPreTrainedModel, QuasarBlock, init/generation/forward and caching conversion logic.
Top-level Quasar ops
fla/ops/quasar/__init__.py
Exports Quasar op entry points (chunk_quasar, fused_recurrent_quasar).
Chunked forward & autograd
fla/ops/quasar/chunk.py
Chunk-wise Quasar forward path and custom autograd Function exposing forward/backward dispatch.
Chunked backward kernels
fla/ops/quasar/chunk_bwd.py
Triton-backed chunked backward kernels and Python launchers for dA/dv and fused WY/dqkb gradients with varlen support.
Intra-chunk kernels & token-parallel
fla/ops/quasar/chunk_intra.py, fla/ops/quasar/chunk_intra_token_parallel.py
Triton kernels for intra-chunk forward/backward, fused inter+solve, token-parallel intra-chunk forward, safe-gate and recompute wiring.
WY recompute & backward
fla/ops/quasar/wy_fast.py
Kernels/wrappers to recompute w/u forward intermediates and prepare WY backward intermediates with tiling and varlen support.
Forward substitution
fla/ops/quasar/forward_substitution.py
Triton forward-substitution kernel and autograd wrapper (backward placeholder).
Fused recurrent forward
fla/ops/quasar/fused_recurrent.py
Fused recurrent Quasar forward Triton kernel, wrapper, and autograd Function (backward unimplemented).
Gate computation
fla/ops/quasar/gate.py
Naive and Triton-optimized Quasar gate (alpha) implementations, autograd-capable function, fast k-accumulation variant, and fused gate entry.

Sequence Diagram

sequenceDiagram
    participant Input as Input (hidden_states / input_ids)
    participant Proj as Projections (q/k/v/(g), conv)
    participant RoPE as RoPE
    participant Norm as L2Norm
    participant Gate as Quasar Gate (beta→alpha)
    participant Kernel as Triton Kernels (chunk / intra / fused_recurrent)
    participant Recomp as WY Recompute / ForwardSubstitution
    participant Model as QuasarModel / QuasarBlock
    participant Output as Output Projection / LM Head

    Input->>Proj: embed & project to q,k,v,(g)
    Proj->>RoPE: apply rotary embeddings (optional)
    RoPE->>Norm: L2-normalize q,k
    Norm->>Gate: compute beta → alpha
    Gate->>Kernel: send alpha, q,k,v,g and cache/state
    Kernel->>Recomp: request w/u, A, or forward-substitution (if needed)
    Recomp->>Kernel: return intermediates (w,u,A,...)
    Kernel->>Model: return attention output (padded back if unpadded)
    Model->>Output: project, apply gated norm, compute logits
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs
  • zhiyuan1i

Poem

🐰 I hop through Triton fields at night, with chunks and gates to make things bright,
I stitch the kernels, bind the thread, and nudge the beta for each head,
Quasar hums — a tiled delight, attention nudges byte by byte. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 17.74% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat: add Quasar Attention and standalone model implementation' accurately and concisely describes the primary changes: introducing Quasar Attention (kernels, layers, models) as a complete feature addition.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Quasar Attention, a high-performance linear attention variant, along with its associated Triton kernels, standalone layers, and HuggingFace-compatible model classes. The review identified several critical issues in the implementation of the custom kernels, including inconsistent alpha recomputation in the backward pass, missing gradients for learnable parameters like A_log and dt_bias, and the omission of state-passing gradients. Additionally, feedback was provided regarding hardcoded execution modes that override user preferences, redundant normalization steps, and duplicated code blocks within the Triton kernels.

Comment thread fla/ops/quasar/chunk.py
Comment on lines +204 to +217
eps = 1e-6
k_norm_sq = (k.float() * k.float()).sum(dim=-1) # [B, T, H]
k_norm_sq = torch.clamp(k_norm_sq, min=0.1, max=10.0)

if beta.dim() == 1:
beta_h = beta.view(1, 1, H).to(k_norm_sq.dtype)
else:
beta_h = beta.to(k_norm_sq.dtype)

beta_h = torch.clamp(beta_h, min=0.01, max=10.0)
# Compute alpha with numerical stability
exp_term = torch.exp(-beta_h * k_norm_sq)
alpha = (1.0 - exp_term) / (k_norm_sq + eps)
beta_tok = alpha.to(dtype=q.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The recomputation of alpha (and beta_tok) in the backward pass is inconsistent with the forward pass. In forward (lines 66-95), alpha is computed using A_log and dt_bias, and beta_tok is the mean of alpha across the key dimension. Here, the recomputation uses a simplified formula that ignores A_log and dt_bias, and misses the mean(dim=-1) reduction. This will lead to incorrect gradients during training.

Comment thread fla/layers/quasar.py

batch_size, q_len, _ = hidden_states.shape
# Force chunk mode to avoid fused_recurrent BT conflict
mode = "chunk"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The mode is hardcoded to "chunk", which overrides the mode parameter passed during initialization (self.mode). This prevents the use of the fused_recurrent kernel even if explicitly requested by the user.

Suggested change
mode = "chunk"
mode = self.mode

Comment thread fla/ops/quasar/chunk.py
g=None,
gk=None,
h0=initial_state_f32,
dht=None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The gradient of the final state (d_final_state) is ignored here. It should be passed to chunk_gated_delta_rule_bwd_dhu to support state-passing training.

Suggested change
dht=None,
dht=d_final_state,

Comment thread fla/ops/quasar/chunk.py
# Token-wise gradient doesn't need / T normalization if it's fed to linear layer
dbeta = torch.clamp(dbeta, min=-1.0, max=1.0)

return dq, dk, dv, dbeta, None, None, None, None, None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The backward pass is missing gradients for A_log and dt_bias. Additionally, the gradient for initial_state (dh0, computed at line 274) is not being returned, which will break state-passing training.

Suggested change
return dq, dk, dv, dbeta, None, None, None, None, None
return dq, dk, dv, dbeta, None, None, dh0, None, None

Comment thread fla/layers/quasar.py Outdated
Comment on lines +236 to +237
if torch.isnan(q).any(): pass # print("!!! NAN IN Q (before RoPE)", flush=True)
if torch.isnan(v).any(): pass # print("!!! NAN IN V", flush=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These debugging statements with pass should be removed to maintain code cleanliness.

Comment on lines +76 to +78
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Redundant normalization. q and k are already normalized in the QuasarAttention.forward method (lines 265-266). Performing normalization again inside the Triton kernel is unnecessary and impacts performance.

Comment on lines +120 to +127
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))

if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code is redundant as it is duplicated immediately after.

Suggested change
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))


# Backward pass: dL = -A^T @ dA @ A^T
# Simplified implementation for now
dL = torch.zeros_like(L)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The backward pass for ForwardSubstitutionFunction is not implemented and returns zeros. This will result in incorrect gradients if this function is used in a differentiable path.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

🟡 Minor comments (15)
fla/distributed_compat.py-8-8 (1)

8-8: ⚠️ Potential issue | 🟡 Minor

Remove unused torch import.

The torch module is imported but never used in this file. All imports are from torch.distributed submodules, which don't require the parent module to be explicitly imported.

🧹 Proposed fix
-import torch
-
 # DeviceMesh
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/distributed_compat.py` at line 8, Remove the unused top-level import of
the torch module by deleting the solitary "import torch" statement; keep the
existing torch.distributed submodule imports intact (they do not require the
parent torch import) and run a quick search for any references to the plain
"torch" symbol in this file to ensure no other code relies on it before
committing.
fla/ops/quasar/fused_recurrent.py-212-214 (1)

212-214: ⚠️ Potential issue | 🟡 Minor

Backward pass not implemented.

The backward pass raises NotImplementedError. This means fused_recurrent_quasar cannot be used during training with gradient computation. Ensure this limitation is documented, and consider adding a warning when the mode is used in a training context.

Would you like me to help document this limitation or open an issue to track the backward implementation?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/fused_recurrent.py` around lines 212 - 214, The backward
staticmethod in fused_recurrent_quasar currently raises NotImplementedError
preventing training — add clear documentation and a runtime warning instead of
just the exception: update the fused_recurrent_quasar class/docstring to state
that the backward pass is unimplemented and the op is inference-only, and in the
forward (or apply) entrypoint detect gradient-required contexts (e.g.,
torch.is_grad_enabled() or any input.requires_grad) and emit a warning via
warnings.warn or the project logger explaining gradients are not supported and
that training will fail; ensure the backward method still raises if actually
called, and add a TODO/issue reference comment to track implementing
fused_recurrent_quasar.backward.
fla/ops/quasar/chunk_bwd.py-9-9 (1)

9-9: ⚠️ Potential issue | 🟡 Minor

Remove unused import.

IS_NVIDIA_BLACKWELL is imported but never used.

🧹 Proposed fix
-from fla.utils import IS_NVIDIA_HOPPER, IS_NVIDIA_BLACKWELL, autotune_cache_kwargs, check_shared_mem
+from fla.utils import IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` at line 9, The import line in chunk_bwd.py
includes an unused symbol IS_NVIDIA_BLACKWELL; remove IS_NVIDIA_BLACKWELL from
the from-import (leaving IS_NVIDIA_HOPPER, autotune_cache_kwargs,
check_shared_mem) so the module no longer imports an unused name and linter
warnings are resolved.
fla/models/quasar/modeling_quasar.py-206-208 (1)

206-208: ⚠️ Potential issue | 🟡 Minor

Add stacklevel=2 to warning for correct caller attribution.

Without explicit stacklevel, the warning will point to this line rather than the caller's location.

🔧 Proposed fix
-            warnings.warn("`QuasarModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`QuasarModel` does not `output_attentions` now, setting it to `False`.", stacklevel=2)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/quasar/modeling_quasar.py` around lines 206 - 208, The warning in
QuasarModel's handling of output_attentions uses warnings.warn without a
stacklevel, so update the call in modeling_quasar.py (the block that checks
output_attentions in the QuasarModel code) to pass stacklevel=2 to warnings.warn
so the warning points to the caller rather than this function; keep the existing
message and behavior but add the stacklevel argument to the warnings.warn
invocation.
fla/ops/quasar/chunk_bwd.py-269-269 (1)

269-269: ⚠️ Potential issue | 🟡 Minor

Use explicit | None instead of implicit Optional.

PEP 484 prohibits implicit Optional for scale: float = None.

🔧 Proposed fix
-    scale: float = None,
+    scale: float | None = None,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` at line 269, Update the function parameter
annotation that currently reads "scale: float = None" to use an explicit
nullable type by changing it to "scale: float | None = None" (or
"Optional[float]" if you prefer typing imports) wherever the parameter appears
in chunk_bwd.py; locate the signature that contains the symbol "scale" in the
function/method that handles chunk backward logic and replace the implicit
Optional pattern with the explicit "| None" union type so it complies with PEP
484.
fla/layers/quasar.py-83-83 (1)

83-83: ⚠️ Potential issue | 🟡 Minor

Remove invalid return type annotation from __init__.

__init__ methods should not have return type annotations other than -> None. The current -> QuasarAttention is invalid.

🔧 Proposed fix
-    ) -> QuasarAttention:
+    ) -> None:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` at line 83, The __init__ of class QuasarAttention has
an invalid return annotation (-> QuasarAttention); change the signature of
QuasarAttention.__init__ to have no return type or explicitly -> None so it
conforms to Python constructor typing (remove or replace the existing ->
QuasarAttention annotation).
fla/models/quasar/modeling_quasar.py-295-305 (1)

295-305: ⚠️ Potential issue | 🟡 Minor

Use exception chaining for proper traceback.

When re-raising a modified exception, use raise ... from exception to preserve the original traceback context.

🔧 Proposed fix
             if "past_key_values" in str(exception):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies",
-                )
+                ) from exception
             else:
                 raise exception
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/quasar/modeling_quasar.py` around lines 295 - 305, In the except
AttributeError as exception block in modeling_quasar.py (the handler that checks
if "past_key_values" is in str(exception)), change the re-raise to use exception
chaining by raising the new AttributeError from the caught exception (i.e., use
"raise AttributeError(... ) from exception") so the original traceback is
preserved; keep the same message that references self.__class__.__name__ and the
generation strategies doc URL.
fla/layers/quasar.py-236-237 (1)

236-237: ⚠️ Potential issue | 🟡 Minor

Remove debug code that causes pipeline failures.

These debug statements with pass on the same line violate E701 and serve no purpose since they're commented out. The pipeline is failing on these lines.

🧹 Proposed fix - remove debug lines
-        if torch.isnan(q).any(): pass # print("!!! NAN IN Q (before RoPE)", flush=True)
-        if torch.isnan(v).any(): pass # print("!!! NAN IN V", flush=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 236 - 237, Remove the two debug
if-statements that combine a conditional and a pass/comment on one line (the
torch.isnan(q).any() and torch.isnan(v).any() checks) in quasar.py; these
violate E701 and are unnecessary — delete those lines (or replace with a proper
multiline check/logging if you need runtime validation) so only valid,
non-commented statements remain around the q and v NaN checks.
fla/layers/quasar.py-80-80 (1)

80-80: ⚠️ Potential issue | 🟡 Minor

Use explicit | None instead of implicit Optional.

PEP 484 prohibits implicit Optional. The layer_idx parameter should use explicit union syntax.

🔧 Proposed fix
-        layer_idx: int = None,
+        layer_idx: int | None = None,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` at line 80, The parameter annotation uses an implicit
Optional by writing "layer_idx: int = None"; update the function/method
signature to use explicit union syntax by changing the annotation to "layer_idx:
int | None = None" (i.e., replace the implicit Optional with int | None for the
layer_idx parameter) so it complies with PEP 484; locate the declaration of
layer_idx in the function or constructor where it appears and apply this change.
fla/layers/quasar.py-6-17 (1)

6-17: ⚠️ Potential issue | 🟡 Minor

Remove unused imports to fix pipeline failures.

The following imports are flagged as unused by Flake8 and are causing pipeline failures:

  • math (line 6)
  • repeat from einops (line 11)
  • RMSNorm from fla.modules (line 15)
  • fused_quasar_gate from fla.ops.quasar.gate (line 17)
🧹 Proposed fix
-import math
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
-from einops import rearrange, repeat
+from einops import rearrange
 from torch.nn import functional as F

 from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
-from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
+from fla.modules import FusedRMSNormGated, ShortConvolution
 from fla.ops.quasar import chunk_quasar, fused_recurrent_quasar
-from fla.ops.quasar.gate import fused_quasar_gate
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 6 - 17, Remove the unused imports causing
Flake8 failures: delete the top-level imports "math", "repeat" (from einops),
"RMSNorm" (from fla.modules) and "fused_quasar_gate" (from fla.ops.quasar.gate)
from the import block so only actually used symbols like torch, nn, rearrange,
get_unpad_data, index_first_axis, pad_input, FusedRMSNormGated,
ShortConvolution, chunk_quasar and fused_recurrent_quasar remain; ensure no
other code references those removed names and run linter to confirm the pipeline
passes.
fla/ops/quasar/chunk_bwd.py-240-243 (1)

240-243: ⚠️ Potential issue | 🟡 Minor

Remove unused local variables flagged by pipeline.

Variables m_k (line 190), b_q (line 241), and b_kdk (line 242) are assigned but never used, causing pipeline failures.

🧹 Proposed fix
     for i_k in range(tl.cdiv(K, BK)):
-        o_k = i_k * BK + tl.arange(0, BK)
-        m_k = o_k < K
+        # o_k and m_k removed - not used

...
-        p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
-        b_q = tl.load(p_q, boundary_check=(0, 1))
-        b_kdk = b_k * b_dk
         b_dk = b_dk + b_dkgb * b_beta[:, None]

Note: Review if o_k is actually needed elsewhere in the loop before removing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 240 - 243, Remove the unused local
variables to fix the pipeline errors: delete or rename the unused assignment to
m_k, and remove or replace the unused b_q and b_kdk assignments in the loop (the
tl.load into b_q and the temporary b_kdk = b_k * b_dk); if the tl.load call must
remain for side-effects, assign its result to _ instead of b_q. Also review use
of o_k in the same loop to ensure it’s still needed before removing any related
code.
fla/ops/quasar/chunk_intra.py-392-392 (1)

392-392: ⚠️ Potential issue | 🟡 Minor

Avoid shadowing Python builtin all.

The variable name all shadows the Python builtin function. Use a different name like total_tokens or num_elements.

🔧 Proposed fix
-    all = B * T
+    total_tokens = B * T
     if IS_VARLEN:
         ...
     ...
-    db += (i_k * all + bos) * H + i_h
+    db += (i_k * total_tokens + bos) * H + i_h
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_intra.py` at line 392, The variable name "all" in
chunk_intra.py (the assignment "all = B * T") shadows the built-in all(); rename
it to a non-conflicting identifier (e.g., total_tokens or num_elements) and
update every reference in the same scope (and any return/formatting/logging that
uses it) to the new name so behavior is unchanged but the builtin is no longer
shadowed; ensure imports or other functions are not affected and run
tests/static checks after renaming.
fla/ops/quasar/chunk.py-5-13 (1)

5-13: ⚠️ Potential issue | 🟡 Minor

Remove unused imports.

Multiple imports are unused according to static analysis and cause linting failures:

  • triton (line 5)
  • chunk_gla_fwd_o_gk (line 9)
  • fused_quasar_gate, fast_quasar_alpha (line 11)
  • autotune_cache_kwargs (line 12)
  • chunk_bwd_dv_local, chunk_bwd_dqkwg (line 13)
🧹 Proposed fix
 import torch
-import triton

 from fla.ops.utils.index import prepare_chunk_indices
 from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
-from fla.ops.gla.chunk import chunk_gla_fwd_o_gk
 from fla.ops.quasar.chunk_intra import chunk_quasar_fwd_intra
-from fla.ops.quasar.gate import fused_quasar_gate, fast_quasar_alpha
-from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_shared_mem, input_guard
-from fla.ops.common.chunk_o import chunk_fwd_o, chunk_bwd_dv_local, chunk_bwd_dqkwg
+from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
+from fla.ops.common.chunk_o import chunk_fwd_o
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk.py` around lines 5 - 13, Remove the unused imports
causing lint errors: delete the triton import and the unused symbols
chunk_gla_fwd_o_gk, fused_quasar_gate, fast_quasar_alpha, autotune_cache_kwargs,
chunk_bwd_dv_local, and chunk_bwd_dqkwg from the import lists in this file;
specifically update the top imports that currently reference triton,
fla.ops.gla.chunk (chunk_gla_fwd_o_gk), fla.ops.quasar.gate (fused_quasar_gate,
fast_quasar_alpha), fla.utils (autotune_cache_kwargs), and
fla.ops.common.chunk_o (chunk_bwd_dv_local, chunk_bwd_dqkwg) so only actually
used names remain, and run tests/linter to confirm no other references to those
symbols exist.
fla/ops/quasar/gate.py-5-5 (1)

5-5: ⚠️ Potential issue | 🟡 Minor

Remove the unused torch.nn.functional import.

Flake8 is already flagging Line 5, so F keeps the lint job red until it is dropped.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/gate.py` at line 5, The import torch.nn.functional as F in
gate.py is unused and causes a Flake8 lint failure; remove the unused import
(the "F" symbol) from the top of the file so only necessary torch imports
remain, e.g., delete the line containing "import torch.nn.functional as F" or
replace it with needed imports used elsewhere in the module.
fla/ops/quasar/forward_substitution.py-8-8 (1)

8-8: ⚠️ Potential issue | 🟡 Minor

Remove the unused check_shared_mem import.

Nothing in this file references it, so Flake8 will keep reporting F401 here.

✂️ Proposed fix
-from fla.utils import IS_AMD, autotune_cache_kwargs, check_shared_mem, input_guard
+from fla.utils import IS_AMD, autotune_cache_kwargs, input_guard
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/forward_substitution.py` at line 8, Remove the unused import
check_shared_mem from the module import list; update the from fla.utils import
line in forward_substitution.py to only import the actually used symbols (e.g.,
IS_AMD, autotune_cache_kwargs, input_guard) so Flake8 F401 is resolved and no
other references to check_shared_mem remain.
🧹 Nitpick comments (7)
fla/distributed_compat.py (1)

48-57: Consider sorting __all__ alphabetically.

Static analysis suggests applying alphabetical sorting to __all__ for consistency.

📝 Proposed sort order
 __all__ = [
     'DeviceMesh',
     'DTensor',
+    'HAS_DISTRIBUTED',
+    'ParallelStyle',
     'Placement',
     'Replicate',
     'Shard',
     'distribute_module',
-    'ParallelStyle',
-    'HAS_DISTRIBUTED',
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/distributed_compat.py` around lines 48 - 57, The __all__ export list is
unsorted; please reorder the entries in the __all__ list alphabetically for
consistency—locate the __all__ definition and change the sequence of the symbols
('DeviceMesh', 'DTensor', 'Placement', 'Replicate', 'Shard',
'distribute_module', 'ParallelStyle', 'HAS_DISTRIBUTED') so they appear in
ascending alphabetical order (by name) while preserving the exact identifier
names and quotes.
fla/models/quasar/configuration_quasar.py (1)

1-5: Minor: Remove leading blank lines.

The file starts with two empty lines before the import statement. Consider removing them for consistency.

✨ Proposed fix
-
-
 from transformers.configuration_utils import PretrainedConfig
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/quasar/configuration_quasar.py` around lines 1 - 5, The file
begins with unnecessary leading blank lines before the import; remove the empty
lines so the first non-blank line is the import statement (from
transformers.configuration_utils import PretrainedConfig) to keep the file tidy
and consistent.
fla/ops/quasar/chunk.py (1)

104-117: Unused unpacked variables should be prefixed with underscore.

Variables qg, Aqk, and Akk (line 104) are unpacked but never used. The same applies in backward (line 226). Prefix with _ to indicate intentionally unused.

♻️ Proposed fix for forward
-    w, u, qg, kg, Aqk, Akk = chunk_quasar_fwd_intra(
+    w, u, _qg, kg, _Aqk, _Akk = chunk_quasar_fwd_intra(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk.py` around lines 104 - 117, The unpacked outputs qg,
Aqk, and Akk returned from chunk_quasar_fwd_intra (and the corresponding unused
outputs in the backward function, e.g., from chunk_quasar_bwd_intra) are never
used; rename them to _qg, _Aqk, and _Akk (and the backward equivalents) to mark
them as intentionally unused and avoid linter warnings—update the unpack targets
where chunk_quasar_fwd_intra and chunk_quasar_bwd_intra are called to use the
underscore-prefixed names.
fla/layers/quasar.py (1)

318-319: Use contextlib.suppress instead of bare try-except-pass.

The pipeline is flagging this pattern. Using contextlib.suppress(TypeError) is more Pythonic.

♻️ Proposed refactor
+import contextlib
...
-                try:
-                    past_key_values.update(
-                        recurrent_state=recurrent_state,
-                        conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
-                        layer_idx=self.layer_idx,
-                        offset=q_len,
-                    )
-                except TypeError:
-                    pass
+                with contextlib.suppress(TypeError):
+                    past_key_values.update(
+                        recurrent_state=recurrent_state,
+                        conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
+                        layer_idx=self.layer_idx,
+                        offset=q_len,
+                    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 318 - 319, Replace the bare try/except
TypeError: pass block with contextlib.suppress(TypeError) to be more Pythonic:
import contextlib at the top (if missing) and wrap the statement(s) that
previously lived inside the try block with a with
contextlib.suppress(TypeError): ... block; locate the existing try/except in the
quasar-related function or method where the except TypeError: pass appears and
convert that block to use the context manager (preserving the same inner
statements and indentation).
fla/ops/quasar/chunk_intra.py (1)

253-272: Forward substitution loops have hardcoded BC=16 assumption.

The loop ranges range(2, min(BC, ...)), range(BC + 2, min(2*BC, ...)), etc., assume BC=16 based on the 4 sub-chunks pattern. However, BC is a compile-time constant passed to the kernel. If BC were changed, the indexing math (i - BC, i - 2*BC, i - 3*BC) would break.

Consider adding an assertion or documentation that BC must be 16 for this kernel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_intra.py` around lines 253 - 272, The
forward-substitution loops in chunk_intra.py assume a 4-subchunk layout by
hardcoding offsets like i-BC, i-2*BC, i-3*BC and loop starts (e.g., range(2,
min(BC, ...))) which will break if the compile-time constant BC is not 16;
update the kernel to either enforce BC==16 or generalize the loops: add a
runtime assertion at kernel entry (reference BC) or compute the number of
subchunks = ceil(H*BC/BC) (or use a loop over subchunk_idx and replace the four
explicit blocks (b_a00/b_Ai00, b_a11/b_Ai11, b_a22/b_Ai22, b_a33/b_Ai33) with a
single parametric block that uses offset = subchunk_idx*BC and compares o_i < i
- offset, and updates the corresponding b_Ai array by index; ensure any naming
(b_a00, b_Ai00, etc.) is replaced by indexed containers or mapped by
subchunk_idx to avoid hardcoded 4-way logic.
fla/ops/quasar/gate.py (1)

16-39: Make the reference path match the kernel's compute precision.

quasar_gate_fwd_kernel promotes beta and lambda_t to fp32 before exp, but naive_quasar_gate() does the whole computation in the incoming dtype and only casts at the end. On bf16/fp16 inputs that means the "reference" path can disagree with the Triton path for dtype reasons alone.

♻️ Proposed alignment
 def naive_quasar_gate(
     beta: torch.Tensor,
     lambda_t: torch.Tensor,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
     Torch reference implementation for QuasarAttention gate computation.
@@
     """
     eps = 1e-8
-    alpha = (1 - torch.exp(-beta.view(-1, 1) * lambda_t)) / (lambda_t + eps)
+    beta_f = beta.reshape(*([1] * (lambda_t.ndim - 2)), -1, 1).to(torch.float32)
+    lambda_f = lambda_t.to(torch.float32)
+    alpha = (1 - torch.exp(-beta_f * lambda_f)) / (lambda_f + eps)
     return alpha.to(output_dtype)

Also applies to: 65-75

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/gate.py` around lines 16 - 39, The reference implementation
naive_quasar_gate runs the math in the incoming dtype which diverges from
quasar_gate_fwd_kernel that promotes inputs to fp32; to fix, cast beta and
lambda_t to torch.float32 (e.g., beta.view(-1,1).to(torch.float32) and
lambda_t.to(torch.float32)), perform the exp and division in fp32 (use an fp32
eps constant), then cast the resulting alpha back to output_dtype before
returning; apply the same change to the corresponding reference backward
function used around lines 65-75 so both forward and backward reference paths
match the kernel's compute precision.
fla/ops/quasar/forward_substitution.py (1)

10-10: The autotune configs bypass NUM_WARPS.

Line 10 defines the architecture-specific warp candidates, but Lines 15-18 rebuild that list as [2, 4, 8], so the AMD/CUDA split never actually applies and CUDA never tries the 16/32-warp cases. Please use NUM_WARPS here, or delete it, so this tuning surface has one source of truth.

⚙️ Proposed fix
 `@triton.autotune`(
     configs=[
         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
-        for num_warps in [2, 4, 8]
+        for num_warps in NUM_WARPS
         for num_stages in [2, 3, 4]
     ],
Based on learnings: Align threshold constants used by check_shared_mem to a single source of truth to avoid semantic drift, rather than duplicating literals.

Also applies to: 13-20

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/forward_substitution.py` at line 10, The autotune candidate
list currently defined by NUM_WARPS is being overwritten later in
forward_substitution.py (the autotune configs that rebuild the list as [2,4,8]),
so the AMD/CUDA split in NUM_WARPS is never used; change the autotune config
code to reference NUM_WARPS instead of rebuilding a hardcoded list (or remove
the duplicate list entirely) so there is a single source of truth, and likewise
consolidate any threshold literals used by check_shared_mem into the same shared
constant so check_shared_mem and the autotuner use the same value.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/layers/quasar.py`:
- Around line 54-66: QuasarBlock is passing expand_v, num_v_heads, and
allow_neg_eigval into QuasarAttention.__init__ but QuasarAttention's signature
lacks these parameters and they are swallowed by **kwargs; either add explicit
parameters expand_v, num_v_heads, and allow_neg_eigval to
QuasarAttention.__init__ (and use them inside the class where appropriate) or
stop passing them from QuasarBlock and ensure the intended behavior (e.g.,
implement the missing logic or rename/forward the correct config fields). Update
the QuasarAttention class signature and internal usage (or the QuasarBlock
invocation) so the three symbols are consistently defined and handled rather
than silently ignored.

In `@fla/models/quasar/modeling_quasar.py`:
- Around line 54-66: QuasarBlock is passing expand_v, num_v_heads, and
allow_neg_eigval into QuasarAttention but those args are not in
QuasarAttention.__init__, so they get swallowed by **kwargs; fix by either (A)
adding explicit parameters expand_v, num_v_heads, allow_neg_eigval to
QuasarAttention.__init__ in fla/layers/quasar.py and wiring them into the
attention behavior, or (B) removing those three arguments from the QuasarBlock
instantiation (the self.attn = QuasarAttention(...) call) so only the supported
parameters (hidden_size, head_dim, num_heads, mode, use_short_conv, conv_size,
conv_bias, layer_idx, norm_eps) are passed; choose the option consistent with
intended functionality and update any internal uses accordingly.

In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 278-283: The elif currently checks the function object instead of
calling it (elif check_shared_mem:) causing CONST_TILING to always become 64;
change that branch to call check_shared_mem with the device index (e.g., elif
check_shared_mem(k.device.index):) so it evaluates available shared memory
correctly and sets CONST_TILING to 64 only when the call returns true; keep the
first branch using check_shared_mem('hopper', k.device.index) unchanged.
- Around line 11-20: The safe_dot function currently uses NVIDIA-only PTX via
tl.inline_asm_elementwise (asm="mov.f32 ...") which will break on AMD; either
guard its use with the existing NVIDIA detection flags (e.g., check
IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL before calling safe_dot or raising a
clear error) or replace the inline asm with a device-agnostic implementation
(for example, return tl.dot(a, b) or an equivalent pure Triton expression) so
that safe_dot and callers (safe_dot) work on non‑NVIDIA devices; update
references to safe_dot and any call sites accordingly.

In `@fla/ops/quasar/chunk_intra_token_parallel.py`:
- Around line 109-110: The code unconditionally creates p_beta_out and calls
tl.store into beta_out even when beta_out can be None; wrap the creation of
p_beta_out and the tl.store call in the same USE_QUASAR_ALPHA guard (or check
beta_out is not None) so the pointer is only made and written when beta_out
exists; specifically, guard the block that references p_beta_out and b_alpha
(the tl.make_block_ptr and tl.store lines) with the USE_QUASAR_ALPHA condition
used earlier so no write occurs when beta_out is None.

In `@fla/ops/quasar/chunk.py`:
- Around line 203-217: The backward pass must reconstruct the exact forward
alpha formula using the saved A_log and dt_bias and compute gradients via the
chain rule; replace the simplified exp(-beta_h * k_norm_sq) path in the backward
method with the forward-equivalent steps: recover A = exp(A_log) and expand beta
to beta_expanded and dt_bias to dt_bias_full to match k_norm_sq shape, compute
g_quasar = -A * softplus(beta_expanded + dt_bias_full), set decay =
exp(g_quasar) and alpha = (1.0 - decay) / (k_norm_sq + eps) (same as forward),
then compute dalpha_dbeta using dalpha/dg * dg/dbeta = ( -decay /
(k_norm_sq+eps) ) * ( -A * sigmoid(beta_expanded + dt_bias_full) ) => A * decay
* sigmoid(...) / (k_norm_sq+eps) and use that in the existing beta gradient
calculations (replace the current dalpha_dbeta that uses k_norm_sq * exp_term /
(k_norm_sq + eps)); ensure you use the saved tensors A_log and dt_bias and the
same eps, softplus, and shapes as in forward (variables: A_log, dt_bias,
beta_h/beta_expanded, g_quasar, decay, alpha).

In `@fla/ops/quasar/forward_substitution.py`:
- Around line 107-119: The backward implementation for
quasar_forward_substitution currently returns zeros_like(L) which silently
blocks gradients; either implement the correct VJP using the saved L and A
(compute dL = -A.transpose(-2,-1) @ dA @ A.transpose(-2,-1) with appropriate
tensor shapes and returns) or, if this op is inference/recompute-only, replace
the zero return with an explicit error (e.g. raise RuntimeError) inside the
static backward(ctx, dA) of the custom autograd Function to prevent silent
gradient masking; locate the backward method where ctx.saved_tensors yields L
and A and update that method accordingly.

In `@fla/ops/quasar/fused_recurrent.py`:
- Around line 120-126: The code contains a duplicated STORE_FINAL_STATE block
that computes p_ht and calls tl.store twice; remove the redundant second block
to ensure final_state is stored only once. Locate the repeated conditional using
STORE_FINAL_STATE that computes p_ht from final_state, i_b, H, i_h, BK, o_k, o_v
and calls tl.store(p_ht, b_h.to(p_ht.dtype.element_ty)), and delete the
duplicate so only a single STORE_FINAL_STATE branch performs the tl.store of
b_h.
- Around line 76-78: The kernel is L2-normalizing b_q and b_k (controlled by
USE_QK_L2NORM_IN_KERNEL) while QuasarAttention.forward already normalizes q and
k with F.normalize; to avoid double normalization, update
QuasarAttention.forward so that when calling fused_recurrent_quasar it either
does not pre-normalize (remove the F.normalize(q, p=2, dim=-1)/F.normalize(k,
p=2, dim=-1) calls) if you want the kernel to handle normalization, or pass
use_qk_l2norm_in_kernel=False to fused_recurrent_quasar so the kernel skips its
own normalization; adjust the call site in QuasarAttention.forward and ensure
consistency with the USE_QK_L2NORM_IN_KERNEL flag and any related parameters.

In `@fla/ops/quasar/gate.py`:
- Around line 125-145: QuasarGateFunction.backward currently returns four
gradients despite forward taking three inputs and computes dbeta incorrectly and
with a wrong reduction; update backward in QuasarGateFunction to (1) return
exactly three gradients to match forward's inputs, (2) compute dbeta =
exp(-beta.view(...)*lambda_t) * dalpha * (lambda_t / (lambda_t + eps)) (i.e.,
multiply the existing dbeta term by lambda_t/(lambda_t+eps)), and (3) replace
the hardcoded .sum(dim=(0,1)) with a reduction that sums over all dimensions of
dalpha except the beta dimension so the resulting dbeta matches beta.shape
(similarly sum dlambda over the non-lambda dimensions so dlambda matches
lambda_t.shape); use ctx.saved_tensors (lambda_t, beta) to determine which dims
to reduce.

---

Minor comments:
In `@fla/distributed_compat.py`:
- Line 8: Remove the unused top-level import of the torch module by deleting the
solitary "import torch" statement; keep the existing torch.distributed submodule
imports intact (they do not require the parent torch import) and run a quick
search for any references to the plain "torch" symbol in this file to ensure no
other code relies on it before committing.

In `@fla/layers/quasar.py`:
- Line 83: The __init__ of class QuasarAttention has an invalid return
annotation (-> QuasarAttention); change the signature of
QuasarAttention.__init__ to have no return type or explicitly -> None so it
conforms to Python constructor typing (remove or replace the existing ->
QuasarAttention annotation).
- Around line 236-237: Remove the two debug if-statements that combine a
conditional and a pass/comment on one line (the torch.isnan(q).any() and
torch.isnan(v).any() checks) in quasar.py; these violate E701 and are
unnecessary — delete those lines (or replace with a proper multiline
check/logging if you need runtime validation) so only valid, non-commented
statements remain around the q and v NaN checks.
- Line 80: The parameter annotation uses an implicit Optional by writing
"layer_idx: int = None"; update the function/method signature to use explicit
union syntax by changing the annotation to "layer_idx: int | None = None" (i.e.,
replace the implicit Optional with int | None for the layer_idx parameter) so it
complies with PEP 484; locate the declaration of layer_idx in the function or
constructor where it appears and apply this change.
- Around line 6-17: Remove the unused imports causing Flake8 failures: delete
the top-level imports "math", "repeat" (from einops), "RMSNorm" (from
fla.modules) and "fused_quasar_gate" (from fla.ops.quasar.gate) from the import
block so only actually used symbols like torch, nn, rearrange, get_unpad_data,
index_first_axis, pad_input, FusedRMSNormGated, ShortConvolution, chunk_quasar
and fused_recurrent_quasar remain; ensure no other code references those removed
names and run linter to confirm the pipeline passes.

In `@fla/models/quasar/modeling_quasar.py`:
- Around line 206-208: The warning in QuasarModel's handling of
output_attentions uses warnings.warn without a stacklevel, so update the call in
modeling_quasar.py (the block that checks output_attentions in the QuasarModel
code) to pass stacklevel=2 to warnings.warn so the warning points to the caller
rather than this function; keep the existing message and behavior but add the
stacklevel argument to the warnings.warn invocation.
- Around line 295-305: In the except AttributeError as exception block in
modeling_quasar.py (the handler that checks if "past_key_values" is in
str(exception)), change the re-raise to use exception chaining by raising the
new AttributeError from the caught exception (i.e., use "raise
AttributeError(... ) from exception") so the original traceback is preserved;
keep the same message that references self.__class__.__name__ and the generation
strategies doc URL.

In `@fla/ops/quasar/chunk_bwd.py`:
- Line 9: The import line in chunk_bwd.py includes an unused symbol
IS_NVIDIA_BLACKWELL; remove IS_NVIDIA_BLACKWELL from the from-import (leaving
IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem) so the module no
longer imports an unused name and linter warnings are resolved.
- Line 269: Update the function parameter annotation that currently reads
"scale: float = None" to use an explicit nullable type by changing it to "scale:
float | None = None" (or "Optional[float]" if you prefer typing imports)
wherever the parameter appears in chunk_bwd.py; locate the signature that
contains the symbol "scale" in the function/method that handles chunk backward
logic and replace the implicit Optional pattern with the explicit "| None" union
type so it complies with PEP 484.
- Around line 240-243: Remove the unused local variables to fix the pipeline
errors: delete or rename the unused assignment to m_k, and remove or replace the
unused b_q and b_kdk assignments in the loop (the tl.load into b_q and the
temporary b_kdk = b_k * b_dk); if the tl.load call must remain for side-effects,
assign its result to _ instead of b_q. Also review use of o_k in the same loop
to ensure it’s still needed before removing any related code.

In `@fla/ops/quasar/chunk_intra.py`:
- Line 392: The variable name "all" in chunk_intra.py (the assignment "all = B *
T") shadows the built-in all(); rename it to a non-conflicting identifier (e.g.,
total_tokens or num_elements) and update every reference in the same scope (and
any return/formatting/logging that uses it) to the new name so behavior is
unchanged but the builtin is no longer shadowed; ensure imports or other
functions are not affected and run tests/static checks after renaming.

In `@fla/ops/quasar/chunk.py`:
- Around line 5-13: Remove the unused imports causing lint errors: delete the
triton import and the unused symbols chunk_gla_fwd_o_gk, fused_quasar_gate,
fast_quasar_alpha, autotune_cache_kwargs, chunk_bwd_dv_local, and
chunk_bwd_dqkwg from the import lists in this file; specifically update the top
imports that currently reference triton, fla.ops.gla.chunk (chunk_gla_fwd_o_gk),
fla.ops.quasar.gate (fused_quasar_gate, fast_quasar_alpha), fla.utils
(autotune_cache_kwargs), and fla.ops.common.chunk_o (chunk_bwd_dv_local,
chunk_bwd_dqkwg) so only actually used names remain, and run tests/linter to
confirm no other references to those symbols exist.

In `@fla/ops/quasar/forward_substitution.py`:
- Line 8: Remove the unused import check_shared_mem from the module import list;
update the from fla.utils import line in forward_substitution.py to only import
the actually used symbols (e.g., IS_AMD, autotune_cache_kwargs, input_guard) so
Flake8 F401 is resolved and no other references to check_shared_mem remain.

In `@fla/ops/quasar/fused_recurrent.py`:
- Around line 212-214: The backward staticmethod in fused_recurrent_quasar
currently raises NotImplementedError preventing training — add clear
documentation and a runtime warning instead of just the exception: update the
fused_recurrent_quasar class/docstring to state that the backward pass is
unimplemented and the op is inference-only, and in the forward (or apply)
entrypoint detect gradient-required contexts (e.g., torch.is_grad_enabled() or
any input.requires_grad) and emit a warning via warnings.warn or the project
logger explaining gradients are not supported and that training will fail;
ensure the backward method still raises if actually called, and add a TODO/issue
reference comment to track implementing fused_recurrent_quasar.backward.

In `@fla/ops/quasar/gate.py`:
- Line 5: The import torch.nn.functional as F in gate.py is unused and causes a
Flake8 lint failure; remove the unused import (the "F" symbol) from the top of
the file so only necessary torch imports remain, e.g., delete the line
containing "import torch.nn.functional as F" or replace it with needed imports
used elsewhere in the module.

---

Nitpick comments:
In `@fla/distributed_compat.py`:
- Around line 48-57: The __all__ export list is unsorted; please reorder the
entries in the __all__ list alphabetically for consistency—locate the __all__
definition and change the sequence of the symbols ('DeviceMesh', 'DTensor',
'Placement', 'Replicate', 'Shard', 'distribute_module', 'ParallelStyle',
'HAS_DISTRIBUTED') so they appear in ascending alphabetical order (by name)
while preserving the exact identifier names and quotes.

In `@fla/layers/quasar.py`:
- Around line 318-319: Replace the bare try/except TypeError: pass block with
contextlib.suppress(TypeError) to be more Pythonic: import contextlib at the top
(if missing) and wrap the statement(s) that previously lived inside the try
block with a with contextlib.suppress(TypeError): ... block; locate the existing
try/except in the quasar-related function or method where the except TypeError:
pass appears and convert that block to use the context manager (preserving the
same inner statements and indentation).

In `@fla/models/quasar/configuration_quasar.py`:
- Around line 1-5: The file begins with unnecessary leading blank lines before
the import; remove the empty lines so the first non-blank line is the import
statement (from transformers.configuration_utils import PretrainedConfig) to
keep the file tidy and consistent.

In `@fla/ops/quasar/chunk_intra.py`:
- Around line 253-272: The forward-substitution loops in chunk_intra.py assume a
4-subchunk layout by hardcoding offsets like i-BC, i-2*BC, i-3*BC and loop
starts (e.g., range(2, min(BC, ...))) which will break if the compile-time
constant BC is not 16; update the kernel to either enforce BC==16 or generalize
the loops: add a runtime assertion at kernel entry (reference BC) or compute the
number of subchunks = ceil(H*BC/BC) (or use a loop over subchunk_idx and replace
the four explicit blocks (b_a00/b_Ai00, b_a11/b_Ai11, b_a22/b_Ai22,
b_a33/b_Ai33) with a single parametric block that uses offset = subchunk_idx*BC
and compares o_i < i - offset, and updates the corresponding b_Ai array by
index; ensure any naming (b_a00, b_Ai00, etc.) is replaced by indexed containers
or mapped by subchunk_idx to avoid hardcoded 4-way logic.

In `@fla/ops/quasar/chunk.py`:
- Around line 104-117: The unpacked outputs qg, Aqk, and Akk returned from
chunk_quasar_fwd_intra (and the corresponding unused outputs in the backward
function, e.g., from chunk_quasar_bwd_intra) are never used; rename them to _qg,
_Aqk, and _Akk (and the backward equivalents) to mark them as intentionally
unused and avoid linter warnings—update the unpack targets where
chunk_quasar_fwd_intra and chunk_quasar_bwd_intra are called to use the
underscore-prefixed names.

In `@fla/ops/quasar/forward_substitution.py`:
- Line 10: The autotune candidate list currently defined by NUM_WARPS is being
overwritten later in forward_substitution.py (the autotune configs that rebuild
the list as [2,4,8]), so the AMD/CUDA split in NUM_WARPS is never used; change
the autotune config code to reference NUM_WARPS instead of rebuilding a
hardcoded list (or remove the duplicate list entirely) so there is a single
source of truth, and likewise consolidate any threshold literals used by
check_shared_mem into the same shared constant so check_shared_mem and the
autotuner use the same value.

In `@fla/ops/quasar/gate.py`:
- Around line 16-39: The reference implementation naive_quasar_gate runs the
math in the incoming dtype which diverges from quasar_gate_fwd_kernel that
promotes inputs to fp32; to fix, cast beta and lambda_t to torch.float32 (e.g.,
beta.view(-1,1).to(torch.float32) and lambda_t.to(torch.float32)), perform the
exp and division in fp32 (use an fp32 eps constant), then cast the resulting
alpha back to output_dtype before returning; apply the same change to the
corresponding reference backward function used around lines 65-75 so both
forward and backward reference paths match the kernel's compute precision.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 3030bcab-3065-48c1-b724-442195645a70

📥 Commits

Reviewing files that changed from the base of the PR and between ca3905f and 1d92a63.

📒 Files selected for processing (17)
  • PR_DESCRIPTION.md
  • fla/distributed_compat.py
  • fla/layers/__init__.py
  • fla/layers/quasar.py
  • fla/models/__init__.py
  • fla/models/quasar/__init__.py
  • fla/models/quasar/configuration_quasar.py
  • fla/models/quasar/modeling_quasar.py
  • fla/ops/quasar/__init__.py
  • fla/ops/quasar/chunk.py
  • fla/ops/quasar/chunk_bwd.py
  • fla/ops/quasar/chunk_intra.py
  • fla/ops/quasar/chunk_intra_token_parallel.py
  • fla/ops/quasar/forward_substitution.py
  • fla/ops/quasar/fused_recurrent.py
  • fla/ops/quasar/gate.py
  • fla/ops/quasar/wy_fast.py

Comment thread fla/layers/quasar.py
Comment on lines +54 to +66
The number of heads. Default: 16.
mode (str, Optional):
Which QuasarAttention kernel to use.
Currently available: `chunk` and `fused_recurrent`.
Default: `chunk`.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `True`.
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
layer_idx (int, Optional):
The index of the layer. Default: None.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check QuasarAttention init signature
echo "=== QuasarAttention __init__ parameters ==="
rg -n "def __init__" -A 20 fla/layers/quasar.py | head -30

echo ""
echo "=== QuasarBlock instantiation of QuasarAttention ==="
rg -n "QuasarAttention\(" -A 15 fla/models/quasar/modeling_quasar.py

Repository: fla-org/flash-linear-attention

Length of output: 1585


🏁 Script executed:

# Check if expand_v, num_v_heads, allow_neg_eigval are used in QuasarAttention
echo "=== Checking usage of expand_v in QuasarAttention ===" 
rg "expand_v" fla/layers/quasar.py

echo ""
echo "=== Checking usage of num_v_heads in QuasarAttention ===" 
rg "num_v_heads" fla/layers/quasar.py

echo ""
echo "=== Checking usage of allow_neg_eigval in QuasarAttention ===" 
rg "allow_neg_eigval" fla/layers/quasar.py

echo ""
echo "=== Full QuasarAttention class definition ===" 
rg -n "class QuasarAttention" -A 150 fla/layers/quasar.py | head -160

Repository: fla-org/flash-linear-attention

Length of output: 7280


Parameter mismatch between QuasarBlock and QuasarAttention.

QuasarBlock passes expand_v, num_v_heads, and allow_neg_eigval to QuasarAttention.__init__, but these parameters are not defined in the signature and never used within the class. They are silently swallowed by **kwargs, indicating either incomplete implementation or incorrect configuration propagation from the model to the layer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 54 - 66, QuasarBlock is passing expand_v,
num_v_heads, and allow_neg_eigval into QuasarAttention.__init__ but
QuasarAttention's signature lacks these parameters and they are swallowed by
**kwargs; either add explicit parameters expand_v, num_v_heads, and
allow_neg_eigval to QuasarAttention.__init__ (and use them inside the class
where appropriate) or stop passing them from QuasarBlock and ensure the intended
behavior (e.g., implement the missing logic or rename/forward the correct config
fields). Update the QuasarAttention class signature and internal usage (or the
QuasarBlock invocation) so the three symbols are consistently defined and
handled rather than silently ignored.

Comment on lines +54 to +66
self.attn = QuasarAttention(
mode=config.attn_mode,
hidden_size=config.hidden_size,
expand_v=config.expand_v,
head_dim=config.head_dim,
num_heads=config.num_heads,
num_v_heads=config.num_v_heads,
use_short_conv=config.use_short_conv,
allow_neg_eigval=config.allow_neg_eigval,
conv_size=config.conv_size,
norm_eps=config.norm_eps,
layer_idx=layer_idx,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

QuasarAttention instantiated with undefined parameters.

QuasarBlock passes expand_v, num_v_heads, and allow_neg_eigval to QuasarAttention, but checking fla/layers/quasar.py, these parameters are not in the __init__ signature. They'll be captured by **kwargs and silently ignored.

Either add these parameters to QuasarAttention.__init__ or remove them from the instantiation here.

🔍 Expected QuasarAttention parameters vs. passed

QuasarAttention accepts:

  • hidden_size, head_dim, num_heads, mode, use_short_conv, conv_size, conv_bias, layer_idx, norm_eps

QuasarBlock passes (extra):

  • expand_v (line 57)
  • num_v_heads (line 60)
  • allow_neg_eigval (line 62)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/models/quasar/modeling_quasar.py` around lines 54 - 66, QuasarBlock is
passing expand_v, num_v_heads, and allow_neg_eigval into QuasarAttention but
those args are not in QuasarAttention.__init__, so they get swallowed by
**kwargs; fix by either (A) adding explicit parameters expand_v, num_v_heads,
allow_neg_eigval to QuasarAttention.__init__ in fla/layers/quasar.py and wiring
them into the attention behavior, or (B) removing those three arguments from the
QuasarBlock instantiation (the self.attn = QuasarAttention(...) call) so only
the supported parameters (hidden_size, head_dim, num_heads, mode,
use_short_conv, conv_size, conv_bias, layer_idx, norm_eps) are passed; choose
the option consistent with intended functionality and update any internal uses
accordingly.

Comment on lines +11 to +20
@triton.jit
def safe_dot(a, b):
return tl.inline_asm_elementwise(
asm="mov.f32 $0, $1;",
constraints="=r,r",
args=[tl.dot(a, b)],
dtype=tl.float32,
is_pure=True,
pack=1,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if IS_AMD is used to guard this function
rg -n "IS_AMD|safe_dot" fla/ops/quasar/chunk_bwd.py

Repository: fla-org/flash-linear-attention

Length of output: 188


🏁 Script executed:

# Read the imports and first 50 lines to see if IS_AMD is imported
head -50 fla/ops/quasar/chunk_bwd.py

Repository: fla-org/flash-linear-attention

Length of output: 1276


🏁 Script executed:

# Check if IS_AMD is used anywhere in the module, including call sites
rg -n "IS_AMD" fla/ops/quasar/chunk_bwd.py

Repository: fla-org/flash-linear-attention

Length of output: 56


🏁 Script executed:

# Look for any conditional logic around safe_dot calls
rg -B5 -A5 "safe_dot" fla/ops/quasar/chunk_bwd.py | head -60

Repository: fla-org/flash-linear-attention

Length of output: 1107


safe_dot uses unguarded NVIDIA-specific inline assembly.

The tl.inline_asm_elementwise with mov.f32 PTX instruction will fail on AMD GPUs. The function has no guards despite the module importing other NVIDIA-specific checks (IS_NVIDIA_HOPPER, IS_NVIDIA_BLACKWELL). Either guard this function against AMD usage or use a device-agnostic approach.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 11 - 20, The safe_dot function
currently uses NVIDIA-only PTX via tl.inline_asm_elementwise (asm="mov.f32 ...")
which will break on AMD; either guard its use with the existing NVIDIA detection
flags (e.g., check IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL before calling
safe_dot or raising a clear error) or replace the inline asm with a
device-agnostic implementation (for example, return tl.dot(a, b) or an
equivalent pure Triton expression) so that safe_dot and callers (safe_dot) work
on non‑NVIDIA devices; update references to safe_dot and any call sites
accordingly.

Comment on lines +278 to +283
if check_shared_mem('hopper', k.device.index):
CONST_TILING = 128
elif check_shared_mem:
CONST_TILING = 64
else:
CONST_TILING = 32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bug: Missing parentheses in check_shared_mem call.

Line 280 uses elif check_shared_mem: which checks if the function object is truthy (always true), not the result of calling it. This causes CONST_TILING = 64 to always be selected on non-Hopper hardware, even when shared memory is insufficient.

🐛 Proposed fix
     if check_shared_mem('hopper', k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem(device_index=k.device.index):
         CONST_TILING = 64
     else:
         CONST_TILING = 32
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if check_shared_mem('hopper', k.device.index):
CONST_TILING = 128
elif check_shared_mem:
CONST_TILING = 64
else:
CONST_TILING = 32
if check_shared_mem('hopper', k.device.index):
CONST_TILING = 128
elif check_shared_mem(device_index=k.device.index):
CONST_TILING = 64
else:
CONST_TILING = 32
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 278 - 283, The elif currently
checks the function object instead of calling it (elif check_shared_mem:)
causing CONST_TILING to always become 64; change that branch to call
check_shared_mem with the device index (e.g., elif
check_shared_mem(k.device.index):) so it evaluates available shared memory
correctly and sets CONST_TILING to 64 only when the call returns true; keep the
first branch using check_shared_mem('hopper', k.device.index) unchanged.

Comment on lines +109 to +110
p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Unconditional write to beta_out when it may be None.

Lines 109-110 create a block pointer to beta_out and store to it unconditionally, but the USE_QUASAR_ALPHA guard is only applied to the pointer offset adjustment at lines 84-85. When beta_out is None, this will cause a runtime error.

🐛 Proposed fix: Guard the store with USE_QUASAR_ALPHA
-    p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
-    tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
+    if USE_QUASAR_ALPHA:
+        p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
+        tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
if USE_QUASAR_ALPHA:
p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_intra_token_parallel.py` around lines 109 - 110, The
code unconditionally creates p_beta_out and calls tl.store into beta_out even
when beta_out can be None; wrap the creation of p_beta_out and the tl.store call
in the same USE_QUASAR_ALPHA guard (or check beta_out is not None) so the
pointer is only made and written when beta_out exists; specifically, guard the
block that references p_beta_out and b_alpha (the tl.make_block_ptr and tl.store
lines) with the USE_QUASAR_ALPHA condition used earlier so no write occurs when
beta_out is None.

Comment thread fla/ops/quasar/chunk.py
Comment on lines +203 to +217
# Recompute alpha
eps = 1e-6
k_norm_sq = (k.float() * k.float()).sum(dim=-1) # [B, T, H]
k_norm_sq = torch.clamp(k_norm_sq, min=0.1, max=10.0)

if beta.dim() == 1:
beta_h = beta.view(1, 1, H).to(k_norm_sq.dtype)
else:
beta_h = beta.to(k_norm_sq.dtype)

beta_h = torch.clamp(beta_h, min=0.01, max=10.0)
# Compute alpha with numerical stability
exp_term = torch.exp(-beta_h * k_norm_sq)
alpha = (1.0 - exp_term) / (k_norm_sq + eps)
beta_tok = alpha.to(dtype=q.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Compare forward and backward alpha formulas
echo "=== Forward alpha computation ==="
rg -n "alpha|decay|softplus" fla/ops/quasar/chunk.py | head -20

echo ""
echo "=== Backward alpha computation ==="
rg -n "exp_term|alpha" fla/ops/quasar/chunk.py | tail -20

Repository: fla-org/flash-linear-attention

Length of output: 2600


🏁 Script executed:

# Read the forward pass around lines 84-95
sed -n '80,100p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention

Length of output: 998


🏁 Script executed:

# Read lines around 187 to see what is saved
sed -n '180,195p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention

Length of output: 586


🏁 Script executed:

# Read the full backward alpha recomputation section
sed -n '200,240p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention

Length of output: 1421


🏁 Script executed:

# Search for usage of A_log and dt_bias in the entire file
rg -n "A_log|dt_bias" fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention

Length of output: 1514


🏁 Script executed:

# Get the full backward method to see if A_log and dt_bias are used
sed -n '195,350p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention

Length of output: 5276


🏁 Script executed:

# Search for any usage of A_log or dt_bias after line 197 in the backward section
sed -n '197,400p' fla/ops/quasar/chunk.py | grep -n "A_log\|dt_bias"

Repository: fla-org/flash-linear-attention

Length of output: 532


Backward alpha recomputation uses different formula than forward, and saved parameters are unused.

The forward pass (lines 84-91) computes alpha using:

g_quasar = -A * softplus(beta_expanded + dt_bias_full)
decay = exp(g_quasar)
alpha = (1.0 - decay) / (k_norm_sq_expanded + 1e-6)

But the backward pass (lines 215-216) uses a simplified formula that omits A_log and dt_bias:

exp_term = exp(-beta_h * k_norm_sq)
alpha = (1.0 - exp_term) / (k_norm_sq + eps)

Additionally, A_log and dt_bias are saved at line 187 and unpacked at line 197, but never used anywhere in the backward method. The gradient computation for beta (lines 326, 333) uses dalpha_dbeta = k_norm_sq * exp_term / (k_norm_sq + eps), which derives from the simplified formula, not the actual forward formula. This causes the computed gradients to be based on a different alpha formula than what was actually used in the forward pass.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk.py` around lines 203 - 217, The backward pass must
reconstruct the exact forward alpha formula using the saved A_log and dt_bias
and compute gradients via the chain rule; replace the simplified exp(-beta_h *
k_norm_sq) path in the backward method with the forward-equivalent steps:
recover A = exp(A_log) and expand beta to beta_expanded and dt_bias to
dt_bias_full to match k_norm_sq shape, compute g_quasar = -A *
softplus(beta_expanded + dt_bias_full), set decay = exp(g_quasar) and alpha =
(1.0 - decay) / (k_norm_sq + eps) (same as forward), then compute dalpha_dbeta
using dalpha/dg * dg/dbeta = ( -decay / (k_norm_sq+eps) ) * ( -A *
sigmoid(beta_expanded + dt_bias_full) ) => A * decay * sigmoid(...) /
(k_norm_sq+eps) and use that in the existing beta gradient calculations (replace
the current dalpha_dbeta that uses k_norm_sq * exp_term / (k_norm_sq + eps));
ensure you use the saved tensors A_log and dt_bias and the same eps, softplus,
and shapes as in forward (variables: A_log, dt_bias, beta_h/beta_expanded,
g_quasar, decay, alpha).

Comment on lines +107 to +119
ctx.save_for_backward(L, A)
return A

@staticmethod
@input_guard
def backward(ctx, dA):
L, A = ctx.saved_tensors

# Backward pass: dL = -A^T @ dA @ A^T
# Simplified implementation for now
dL = torch.zeros_like(L)

return dL
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Please don't ship a zero gradient from this custom backward.

The comment on Line 115 already states the right VJP form, but the implementation returns zeros_like(L). That makes quasar_forward_substitution() look differentiable while silently blocking every upstream update that flows through L. If this helper is inference/recompute-only for now, raising from backward() is much safer than returning a valid-looking zero tensor.

🧮 Proposed fix
     def forward(
         ctx,
         L: torch.Tensor,
     ):
         A = forward_substitution(L)
-        ctx.save_for_backward(L, A)
+        ctx.save_for_backward(A)
         return A
@@
     def backward(ctx, dA):
-        L, A = ctx.saved_tensors
-        
-        # Backward pass: dL = -A^T @ dA @ A^T
-        # Simplified implementation for now
-        dL = torch.zeros_like(L)
-        
+        (A,) = ctx.saved_tensors
+        gA = torch.tril(dA, diagonal=-1)
+        dL = -(A.transpose(-1, -2) @ gA @ A.transpose(-1, -2))
+        dL = torch.tril(dL, diagonal=-1)
         return dL
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 113-113: Unpacked variable A is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/forward_substitution.py` around lines 107 - 119, The backward
implementation for quasar_forward_substitution currently returns zeros_like(L)
which silently blocks gradients; either implement the correct VJP using the
saved L and A (compute dL = -A.transpose(-2,-1) @ dA @ A.transpose(-2,-1) with
appropriate tensor shapes and returns) or, if this op is
inference/recompute-only, replace the zero return with an explicit error (e.g.
raise RuntimeError) inside the static backward(ctx, dA) of the custom autograd
Function to prevent silent gradient masking; locate the backward method where
ctx.saved_tensors yields L and A and update that method accordingly.

Comment on lines +76 to +78
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Check if QuasarAttention normalizes q/k before calling fused_recurrent_quasar

# Search for F.normalize calls followed by fused_recurrent_quasar with use_qk_l2norm_in_kernel=True
rg -n -A 30 'F\.normalize.*p=2' --type py | rg -A 20 'fused_recurrent_quasar'

Repository: fla-org/flash-linear-attention

Length of output: 590


🏁 Script executed:

# Get line count and read the relevant section in quasar.py
wc -l fla/layers/quasar.py

Repository: fla-org/flash-linear-attention

Length of output: 98


🏁 Script executed:

# Read the QuasarAttention.forward method around lines 265-295
sed -n '260,300p' fla/layers/quasar.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 1869


🏁 Script executed:

# Read the kernel normalization code in fused_recurrent.py
sed -n '70,85p' fla/ops/quasar/fused_recurrent.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 753


🏁 Script executed:

# Search for the use_qk_l2norm_in_kernel parameter being passed
rg -n 'use_qk_l2norm_in_kernel' fla/layers/quasar.py

Repository: fla-org/flash-linear-attention

Length of output: 173


Remove redundant L2-normalization in kernel when vectors are already normalized.

The code normalizes q and k with F.normalize(q, p=2, dim=-1) and F.normalize(k, p=2, dim=-1) (lines 264-266), then passes them to fused_recurrent_quasar with use_qk_l2norm_in_kernel=True. The kernel then applies L2-normalization again:

if USE_QK_L2NORM_IN_KERNEL:
    b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
    b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)

This results in double normalization. Either remove the pre-normalization in QuasarAttention.forward when using kernel normalization, or pass use_qk_l2norm_in_kernel=False to avoid the redundant kernel normalization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/fused_recurrent.py` around lines 76 - 78, The kernel is
L2-normalizing b_q and b_k (controlled by USE_QK_L2NORM_IN_KERNEL) while
QuasarAttention.forward already normalizes q and k with F.normalize; to avoid
double normalization, update QuasarAttention.forward so that when calling
fused_recurrent_quasar it either does not pre-normalize (remove the
F.normalize(q, p=2, dim=-1)/F.normalize(k, p=2, dim=-1) calls) if you want the
kernel to handle normalization, or pass use_qk_l2norm_in_kernel=False to
fused_recurrent_quasar so the kernel skips its own normalization; adjust the
call site in QuasarAttention.forward and ensure consistency with the
USE_QK_L2NORM_IN_KERNEL flag and any related parameters.

Comment on lines +120 to +126
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))

if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Duplicate STORE_FINAL_STATE block.

The final state storage logic is duplicated at lines 120-122 and 124-126. This causes the final state to be written twice unnecessarily.

🐛 Proposed fix: Remove duplicate block
     if STORE_FINAL_STATE:
         p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
-        
-    if STORE_FINAL_STATE:
-        p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
-        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/fused_recurrent.py` around lines 120 - 126, The code contains
a duplicated STORE_FINAL_STATE block that computes p_ht and calls tl.store
twice; remove the redundant second block to ensure final_state is stored only
once. Locate the repeated conditional using STORE_FINAL_STATE that computes p_ht
from final_state, i_b, H, i_h, BK, o_k, o_v and calls tl.store(p_ht,
b_h.to(p_ht.dtype.element_ty)), and delete the duplicate so only a single
STORE_FINAL_STATE branch performs the tl.store of b_h.

Comment thread fla/ops/quasar/gate.py
Comment on lines +125 to +145
def backward(ctx, dalpha: torch.Tensor):
lambda_t, beta = ctx.saved_tensors
eps = 1e-8

# dalpha/dlambda and dalpha/dbeta derivatives
# alpha = (1 - exp(-beta * lambda)) / (lambda + eps)
# dalpha/dbeta = exp(-beta * lambda)
beta_exp = torch.exp(-beta.view(-1, 1) * lambda_t)
lambda_plus_eps = lambda_t + eps

# dalpha/dlambda = (beta * exp(-beta * lambda) * lambda - (1 - exp(-beta * lambda))) / lambda^2
dlambda = (beta.view(-1, 1) * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)

# dalpha/dbeta = exp(-beta * lambda)
dbeta = beta_exp

dlambda = dlambda * dalpha
# Sum over sequence and dimensions, but preserve head dimension
dbeta = (dbeta * dalpha).sum(dim=(0, 1))

return dlambda, dbeta, None, None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify the custom autograd forward/backward contract in fla/ops/quasar/gate.py.
python - <<'PY'
import ast
from pathlib import Path

path = Path("fla/ops/quasar/gate.py")
tree = ast.parse(path.read_text())

for node in tree.body:
    if isinstance(node, ast.ClassDef) and node.name == "QuasarGateFunction":
        methods = {m.name: m for m in node.body if isinstance(m, ast.FunctionDef)}
        fwd = methods["forward"]
        bwd = methods["backward"]

        forward_inputs = [arg.arg for arg in fwd.args.args[1:]]  # exclude ctx
        print("forward inputs:", forward_inputs)

        for stmt in bwd.body:
            if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Tuple):
                print("backward return count:", len(stmt.value.elts))
                print("backward return exprs:", [ast.unparse(e) for e in stmt.value.elts])
PY

Repository: fla-org/flash-linear-attention

Length of output: 211


🏁 Script executed:

cat -n fla/ops/quasar/gate.py | head -150

Repository: fla-org/flash-linear-attention

Length of output: 5423


Fix QuasarGateFunction.backward() before using in training.

forward() accepts three inputs, but backward() returns four gradients on line 145, violating torch.autograd.Function's contract. Additionally, dbeta omits the required lambda_t / (lambda_t + eps) factor in its derivative, and the hardcoded .sum(dim=(0, 1)) reduction incorrectly assumes 3D tensors and leaves a singleton dimension instead of matching beta.shape.

🐛 Proposed fix
     def backward(ctx, dalpha: torch.Tensor):
         lambda_t, beta = ctx.saved_tensors
         eps = 1e-8
+        beta_view = beta.reshape(*([1] * (lambda_t.ndim - 2)), -1, 1)
 
-        # dalpha/dlambda and dalpha/dbeta derivatives
-        # alpha = (1 - exp(-beta * lambda)) / (lambda + eps)
-        # dalpha/dbeta = exp(-beta * lambda)
-        beta_exp = torch.exp(-beta.view(-1, 1) * lambda_t)
+        beta_exp = torch.exp(-beta_view * lambda_t)
         lambda_plus_eps = lambda_t + eps
 
-        # dalpha/dlambda = (beta * exp(-beta * lambda) * lambda - (1 - exp(-beta * lambda))) / lambda^2
-        dlambda = (beta.view(-1, 1) * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)
-
-        # dalpha/dbeta = exp(-beta * lambda)
-        dbeta = beta_exp
-
+        dlambda = (beta_view * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)
         dlambda = dlambda * dalpha
-        # Sum over sequence and dimensions, but preserve head dimension
-        dbeta = (dbeta * dalpha).sum(dim=(0, 1))
-
-        return dlambda, dbeta, None, None
+        dbeta = (lambda_t * beta_exp / lambda_plus_eps) * dalpha
+        reduce_dims = tuple(i for i in range(dbeta.ndim) if i != dbeta.ndim - 2)
+        dbeta = dbeta.sum(dim=reduce_dims)
+
+        return dlambda, dbeta, None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def backward(ctx, dalpha: torch.Tensor):
lambda_t, beta = ctx.saved_tensors
eps = 1e-8
# dalpha/dlambda and dalpha/dbeta derivatives
# alpha = (1 - exp(-beta * lambda)) / (lambda + eps)
# dalpha/dbeta = exp(-beta * lambda)
beta_exp = torch.exp(-beta.view(-1, 1) * lambda_t)
lambda_plus_eps = lambda_t + eps
# dalpha/dlambda = (beta * exp(-beta * lambda) * lambda - (1 - exp(-beta * lambda))) / lambda^2
dlambda = (beta.view(-1, 1) * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)
# dalpha/dbeta = exp(-beta * lambda)
dbeta = beta_exp
dlambda = dlambda * dalpha
# Sum over sequence and dimensions, but preserve head dimension
dbeta = (dbeta * dalpha).sum(dim=(0, 1))
return dlambda, dbeta, None, None
def backward(ctx, dalpha: torch.Tensor):
lambda_t, beta = ctx.saved_tensors
eps = 1e-8
beta_view = beta.reshape(*([1] * (lambda_t.ndim - 2)), -1, 1)
beta_exp = torch.exp(-beta_view * lambda_t)
lambda_plus_eps = lambda_t + eps
dlambda = (beta_view * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)
dlambda = dlambda * dalpha
dbeta = (lambda_t * beta_exp / lambda_plus_eps) * dalpha
reduce_dims = tuple(i for i in range(dbeta.ndim) if i != dbeta.ndim - 2)
dbeta = dbeta.sum(dim=reduce_dims)
return dlambda, dbeta, None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/gate.py` around lines 125 - 145, QuasarGateFunction.backward
currently returns four gradients despite forward taking three inputs and
computes dbeta incorrectly and with a wrong reduction; update backward in
QuasarGateFunction to (1) return exactly three gradients to match forward's
inputs, (2) compute dbeta = exp(-beta.view(...)*lambda_t) * dalpha * (lambda_t /
(lambda_t + eps)) (i.e., multiply the existing dbeta term by
lambda_t/(lambda_t+eps)), and (3) replace the hardcoded .sum(dim=(0,1)) with a
reduction that sums over all dimensions of dalpha except the beta dimension so
the resulting dbeta matches beta.shape (similarly sum dlambda over the
non-lambda dimensions so dlambda matches lambda_t.shape); use ctx.saved_tensors
(lambda_t, beta) to determine which dims to reduce.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
fla/ops/quasar/chunk_bwd.py (2)

11-20: ⚠️ Potential issue | 🟠 Major

Guard safe_dot for non-NVIDIA backends.

Line 13 uses NVIDIA PTX inline asm without a backend fallback. This can break AMD execution paths.

🐛 Proposed fix
 `@triton.jit`
 def safe_dot(a, b):
-    return tl.inline_asm_elementwise(
-        asm="mov.f32 $0, $1;",
-        constraints="=r,r",
-        args=[tl.dot(a, b)],
-        dtype=tl.float32,
-        is_pure=True,
-        pack=1,
-    )
+    if IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL:
+        return tl.inline_asm_elementwise(
+            asm="mov.f32 $0, $1;",
+            constraints="=r,r",
+            args=[tl.dot(a, b)],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+    return tl.dot(a, b)
#!/bin/bash
# Verify NVIDIA-specific inline PTX is present and currently unguarded in safe_dot.
rg -n -C2 'def safe_dot|inline_asm_elementwise|mov\.f32|IS_NVIDIA_HOPPER|IS_NVIDIA_BLACKWELL|IS_AMD' fla/ops/quasar/chunk_bwd.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 11 - 20, The safe_dot function uses
NVIDIA PTX via tl.inline_asm_elementwise (asm="mov.f32 ...") with no backend
guard, which will break non‑NVIDIA backends; update safe_dot to detect the
backend (e.g., use existing flags like IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL
or a runtime check) and only call tl.inline_asm_elementwise when on an NVIDIA
backend, otherwise fall back to a pure Triton implementation (e.g., return
tl.dot(a, b) or equivalent) so that tl.inline_asm_elementwise is never invoked
on AMD/other backends.

275-279: ⚠️ Potential issue | 🔴 Critical

Fix callable misuse in shared-memory branch selection.

Line 277 checks the function object instead of calling it, so this branch is always truthy when Line 275 is false.

🐛 Proposed fix
     if check_shared_mem('hopper', k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem(device_index=k.device.index):
         CONST_TILING = 64
     else:
         CONST_TILING = 32
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 275 - 279, The elif branch is
testing the function object instead of calling it, so replace the incorrect
check with a proper call to check_shared_mem using the same arguments as the
first branch (i.e., call check_shared_mem('hopper', k.device.index)); update the
branch that sets CONST_TILING (and ensure CONST_TILING remains assigned 64 when
the call returns True) to use check_shared_mem('hopper', k.device.index) instead
of check_shared_mem so the shared-memory detection logic using check_shared_mem
and k.device.index works as intended.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 188-190: The loop currently defines unused locals m_k and o_k
inside the inner K-loop (variables m_k and o_k in the for i_k in
range(tl.cdiv(K, BK)) block), causing a lint F841; remove m_k and o_k and update
any downstream uses to either compute the required index expression inline
(e.g., i_k * BK + tl.arange(0, BK)) or eliminate the computation entirely if
it’s not used; ensure the loop body only keeps the necessary expressions and
that BK, K, and i_k remain correct for any remaining logic in chunk_bwd.py.

---

Duplicate comments:
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 11-20: The safe_dot function uses NVIDIA PTX via
tl.inline_asm_elementwise (asm="mov.f32 ...") with no backend guard, which will
break non‑NVIDIA backends; update safe_dot to detect the backend (e.g., use
existing flags like IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL or a runtime check)
and only call tl.inline_asm_elementwise when on an NVIDIA backend, otherwise
fall back to a pure Triton implementation (e.g., return tl.dot(a, b) or
equivalent) so that tl.inline_asm_elementwise is never invoked on AMD/other
backends.
- Around line 275-279: The elif branch is testing the function object instead of
calling it, so replace the incorrect check with a proper call to
check_shared_mem using the same arguments as the first branch (i.e., call
check_shared_mem('hopper', k.device.index)); update the branch that sets
CONST_TILING (and ensure CONST_TILING remains assigned 64 when the call returns
True) to use check_shared_mem('hopper', k.device.index) instead of
check_shared_mem so the shared-memory detection logic using check_shared_mem and
k.device.index works as intended.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a5eb4037-cb62-4130-b919-e4ed3d0ef7b6

📥 Commits

Reviewing files that changed from the base of the PR and between 1d92a63 and a236bff.

📒 Files selected for processing (1)
  • fla/ops/quasar/chunk_bwd.py

Comment thread fla/ops/quasar/chunk_bwd.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

♻️ Duplicate comments (2)
fla/ops/quasar/chunk_bwd.py (2)

272-277: ⚠️ Potential issue | 🔴 Critical

Call check_shared_mem(...) here instead of testing the function object.

elif check_shared_mem: is always truthy, so non-Hopper devices always end up with CONST_TILING = 64. That bypasses the actual shared-memory check and can select an invalid tile size.

🛠️ Suggested change
     if check_shared_mem('hopper', k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem(device_index=k.device.index):
         CONST_TILING = 64
     else:
         CONST_TILING = 32
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 272 - 277, The elif is testing the
function object instead of calling it, so CONST_TILING is set to 64 for
non-Hopper devices regardless of actual shared memory — call check_shared_mem
with the same device identifier used in the first branch (e.g.,
check_shared_mem('hopper', k.device.index) vs check_shared_mem('some_tag',
k.device.index) or simply check_shared_mem(k.device.index) depending on the
function signature) to perform the real shared-memory test; update the elif to
call check_shared_mem(...) (and pass the appropriate 'hopper' or device args
consistent with the first call) so CONST_TILING is chosen based on the actual
check_shared_mem result, not the function object.

11-20: ⚠️ Potential issue | 🟠 Major

Guard the PTX-only safe_dot path.

tl.inline_asm_elementwise here hardcodes a PTX mov.f32, so this kernel only compiles on NVIDIA backends. Without an AMD/ROCm guard or a plain Triton fallback, the new backward path is not portable across the backends FLA supports.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 11 - 20, safe_dot currently
hardcodes PTX via tl.inline_asm_elementwise with asm="mov.f32 ..." which only
works on NVIDIA; wrap that PTX-specific call in a guarded path and provide a
portable fallback (e.g., return tl.dot(a, b) or tl.move-equivalent) when PTX
inline asm fails or when the backend is not NVIDIA. Concretely, update safe_dot
to attempt the tl.inline_asm_elementwise(asm="mov.f32 ...", constraints="=r,r",
args=[tl.dot(a,b)], ...) inside a try/except (or behind a runtime check for
NVIDIA) and on exception or non-NVIDIA return the plain Triton expression
tl.dot(a, b) so the function works across AMD/ROCm and other backends. Ensure
references to safe_dot, tl.inline_asm_elementwise, and the asm string are used
so reviewers can find the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/layers/quasar.py`:
- Around line 6-18: The import list in fla.layers.quasar.py includes unused
symbols (math, repeat from einops, RMSNorm from fla.modules, and
fused_quasar_gate from fla.ops.quasar.gate) causing F401; remove these unused
imports from the top of the module so only used names (contextlib, torch, nn,
rearrange, F, get_unpad_data, index_first_axis, pad_input, FusedRMSNormGated,
ShortConvolution, chunk_quasar, fused_recurrent_quasar) remain; locate and edit
the import block at the top of quasar.py to drop the four unused identifiers.
- Around line 199-205: The no-unpadding fast path fails because when
attention_mask.all() is true we set indices=None but still run the
RoPE/unpad-dependent logic and pad_input with None; modify the branches that
call get_unpad_data, index_first_axis, the RoPE handling, and pad_input to only
execute when unpadding actually occurred (i.e., indices is not None or a
did_unpad flag is true). Concretely, in the blocks around attention_mask, the
get_unpad_data call that sets indices, cu_seqlens must gate subsequent uses
(index_first_axis, rearrange-hidden_states handling, RoPE code, and pad_input)
on indices is not None (or set a boolean did_unpad and check it) so dense masks
(attention_mask.all()) skip all unpadding/padding and RoPE branches.
- Around line 171-174: Don't unconditionally overwrite the local mode variable;
use the layer's configured mode (self.mode) instead of forcing mode = "chunk",
so the fused_recurrent branch can run in eval. Set mode = self.mode (or fall
back to "chunk" if absent), keep the training-time guard assert to enforce that
training requires "chunk" (i.e., assert self.mode == "chunk" when
self.training), and leave the fused_recurrent branch to select kernels based on
this mode. Reference: self.mode, mode variable, self.training, and the
fused_recurrent branch in fla/layers/quasar.py.

In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 262-263: The wrapper functions (e.g., chunk_quasar_bwd_dAv)
currently declare A: torch.Tensor | None = None and scale: float = None but pass
them directly into kernels that unconditionally read A and multiply by scale;
remove the spurious None defaults or normalize/validate them before any kernel
launch by making A and scale required parameters (no default) or adding an
explicit check at the start of the wrapper (raise a clear error if A is None or
scale is None) or set a valid default value for scale if intended; update any
other wrapper that uses the same pattern (the other launcher referenced around
the same area) to perform the same validation so kernels never receive None for
A or scale.

---

Duplicate comments:
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 272-277: The elif is testing the function object instead of
calling it, so CONST_TILING is set to 64 for non-Hopper devices regardless of
actual shared memory — call check_shared_mem with the same device identifier
used in the first branch (e.g., check_shared_mem('hopper', k.device.index) vs
check_shared_mem('some_tag', k.device.index) or simply
check_shared_mem(k.device.index) depending on the function signature) to perform
the real shared-memory test; update the elif to call check_shared_mem(...) (and
pass the appropriate 'hopper' or device args consistent with the first call) so
CONST_TILING is chosen based on the actual check_shared_mem result, not the
function object.
- Around line 11-20: safe_dot currently hardcodes PTX via
tl.inline_asm_elementwise with asm="mov.f32 ..." which only works on NVIDIA;
wrap that PTX-specific call in a guarded path and provide a portable fallback
(e.g., return tl.dot(a, b) or tl.move-equivalent) when PTX inline asm fails or
when the backend is not NVIDIA. Concretely, update safe_dot to attempt the
tl.inline_asm_elementwise(asm="mov.f32 ...", constraints="=r,r",
args=[tl.dot(a,b)], ...) inside a try/except (or behind a runtime check for
NVIDIA) and on exception or non-NVIDIA return the plain Triton expression
tl.dot(a, b) so the function works across AMD/ROCm and other backends. Ensure
references to safe_dot, tl.inline_asm_elementwise, and the asm string are used
so reviewers can find the change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 28a852dc-bfb2-417f-88ad-77ce8bde2fcc

📥 Commits

Reviewing files that changed from the base of the PR and between a236bff and 84ad1cc.

📒 Files selected for processing (2)
  • fla/layers/quasar.py
  • fla/ops/quasar/chunk_bwd.py

Comment thread fla/layers/quasar.py
Comment on lines +6 to +18
import contextlib
import math
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch.nn import functional as F

from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.quasar import chunk_quasar, fused_recurrent_quasar
from fla.ops.quasar.gate import fused_quasar_gate
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Clean up the unused imports before merge.

math, repeat, RMSNorm, and fused_quasar_gate are still unused in this module, so Flake8 will keep reporting F401 here.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 7-7: 'math' imported but unused

(F401)


[error] 12-12: 'einops.repeat' imported but unused

(F401)


[error] 16-16: 'fla.modules.RMSNorm' imported but unused

(F401)


[error] 18-18: 'fla.ops.quasar.gate.fused_quasar_gate' imported but unused

(F401)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 6 - 18, The import list in
fla.layers.quasar.py includes unused symbols (math, repeat from einops, RMSNorm
from fla.modules, and fused_quasar_gate from fla.ops.quasar.gate) causing F401;
remove these unused imports from the top of the module so only used names
(contextlib, torch, nn, rearrange, F, get_unpad_data, index_first_axis,
pad_input, FusedRMSNormGated, ShortConvolution, chunk_quasar,
fused_recurrent_quasar) remain; locate and edit the import block at the top of
quasar.py to drop the four unused identifiers.

Comment thread fla/layers/quasar.py
Comment on lines +171 to +174
# Force chunk mode to avoid fused_recurrent BT conflict
mode = "chunk"
if self.training:
assert mode == "chunk", "Only chunk mode is supported in training."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

self.mode is ignored here.

mode is overwritten to "chunk" unconditionally, so the fused_recurrent branch below is dead even in eval. That silently ignores the layer config and defeats the decode-path kernel selection.

🛠️ Suggested change
-        # Force chunk mode to avoid fused_recurrent BT conflict
-        mode = "chunk"
-        if self.training:
-            assert mode == "chunk", "Only chunk mode is supported in training."
+        mode = "chunk" if self.training else self.mode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Force chunk mode to avoid fused_recurrent BT conflict
mode = "chunk"
if self.training:
assert mode == "chunk", "Only chunk mode is supported in training."
# Force chunk mode to avoid fused_recurrent BT conflict
mode = "chunk" if self.training else self.mode
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 171 - 174, Don't unconditionally overwrite
the local mode variable; use the layer's configured mode (self.mode) instead of
forcing mode = "chunk", so the fused_recurrent branch can run in eval. Set mode
= self.mode (or fall back to "chunk" if absent), keep the training-time guard
assert to enforce that training requires "chunk" (i.e., assert self.mode ==
"chunk" when self.training), and leave the fused_recurrent branch to select
kernels based on this mode. Reference: self.mode, mode variable, self.training,
and the fused_recurrent branch in fla/layers/quasar.py.

Comment thread fla/layers/quasar.py
Comment on lines +199 to +205
if attention_mask is not None:
# Optimization: Skip unpadding if all tokens are valid (common in packed distillation)
if attention_mask.all():
indices, cu_seqlens = None, None
else:
indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Dense masks break the no-unpadding fast path.

When attention_mask.all() is true, indices stays None, but the RoPE branch and the final pad_input(...) branch still execute solely because attention_mask is not None. That sends None into index_first_axis(...)/pad_input(...) and breaks the common “all tokens valid” path. Gate those branches on indices is not None (or an explicit did_unpad flag) instead.

Also applies to: 241-257, 333-334

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/quasar.py` around lines 199 - 205, The no-unpadding fast path
fails because when attention_mask.all() is true we set indices=None but still
run the RoPE/unpad-dependent logic and pad_input with None; modify the branches
that call get_unpad_data, index_first_axis, the RoPE handling, and pad_input to
only execute when unpadding actually occurred (i.e., indices is not None or a
did_unpad flag is true). Concretely, in the blocks around attention_mask, the
get_unpad_data call that sets indices, cu_seqlens must gate subsequent uses
(index_first_axis, rearrange-hidden_states handling, RoPE code, and pad_input)
on indices is not None (or set a boolean did_unpad and check it) so dense masks
(attention_mask.all()) skip all unpadding/padding and RoPE branches.

Comment on lines +262 to +263
A: torch.Tensor | None = None,
scale: float = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

The wrapper defaults advertise unsupported calls.

Both launchers pass scale straight into kernels that multiply by it, and chunk_quasar_bwd_dAv also forwards A to a kernel that unconditionally loads from it. None is not a real default here, so the next direct caller gets a launch-time failure unless these parameters are made required or normalized before launch.

🛠️ Suggested change
 def chunk_quasar_bwd_dAv(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     do: torch.Tensor,
-    A: torch.Tensor | None = None,
-    scale: float = None,
+    A: torch.Tensor,
+    scale: float,
     cu_seqlens: torch.LongTensor | None = None,
     chunk_size: int = 64,
     chunk_indices: torch.LongTensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@
 def chunk_quasar_bwd_wy_dqkb_fused(
@@
-    scale: float | None = None,
+    scale: float,
     cu_seqlens: torch.LongTensor | None = None,
     chunk_size: int = 64,
     chunk_indices: torch.LongTensor | None = None,
 ):

Also applies to: 318-318

🧰 Tools
🪛 Ruff (0.15.7)

[warning] 263-263: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/quasar/chunk_bwd.py` around lines 262 - 263, The wrapper functions
(e.g., chunk_quasar_bwd_dAv) currently declare A: torch.Tensor | None = None and
scale: float = None but pass them directly into kernels that unconditionally
read A and multiply by scale; remove the spurious None defaults or
normalize/validate them before any kernel launch by making A and scale required
parameters (no default) or adding an explicit check at the start of the wrapper
(raise a clear error if A is None or scale is None) or set a valid default value
for scale if intended; update any other wrapper that uses the same pattern (the
other launcher referenced around the same area) to perform the same validation
so kernels never receive None for A or scale.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant